package com.github.xzb617.client.route.scg.ribbon;

import com.github.xzb617.client.route.scg.constant.RouteKey;
import com.github.xzb617.client.route.scg.ribbon.support.FilterResult;
import com.github.xzb617.client.route.scg.ribbon.support.ServerHelper;
import com.netflix.loadbalancer.*;

import io.jmnarloch.spring.cloud.ribbon.api.RibbonFilterContext;
import io.jmnarloch.spring.cloud.ribbon.predicate.MetadataAwarePredicate;
import io.jmnarloch.spring.cloud.ribbon.rule.DiscoveryEnabledRule;
import io.jmnarloch.spring.cloud.ribbon.support.RibbonFilterContextHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;

import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Custom Ribbon load-balancing rule.
 * <p>
 *     Supports dynamic traffic coloring (routing part of the traffic to a tagged group of instances) for:
 *     (1) canary releases
 *     (2) blue-green deployments
 *     (3) end-to-end (full-link) gray releases
 * </p>
 * <p>
 *     Servers are picked round-robin from a candidate list that depends on the route rule carried in the
 *     current {@link RibbonFilterContext}. The {@link MetadataAwarePredicate} handed to the parent constructor
 *     is not consulted by {@link #choose(Object)}.
 * </p>
 */
public class GoalkeeperRoundRobinRule extends DiscoveryEnabledRule {

    /** Logger shared by all instances of this rule. */
    private static final Logger LOGGER = LoggerFactory.getLogger(GoalkeeperRoundRobinRule.class);

    /** Maximum number of round-robin picks {@link #choose(Object)} makes before giving up. */
    private static final int MAX_SELECTION_ATTEMPTS = 10;

    /** Wildcard value of the caller/callee route attributes, meaning "any service". */
    private static final String WILDCARD = "*";

    /**
     * Name of the local application, used as the "caller" when matching route rules.
     * NOTE(review): injected by Spring via {@code @Value}; it stays {@code null} if this rule is instantiated
     * outside the Spring container — confirm how the rule is created.
     */
    @Value("${spring.application.name}")
    private String applicationName;

    /** Shared round-robin position, advanced lock-free with compare-and-set. */
    private final AtomicInteger nextServerPositionCounter;

    public GoalkeeperRoundRobinRule() {
        super(new MetadataAwarePredicate());
        this.nextServerPositionCounter = new AtomicInteger(0);
    }

    /**
     * Returns {@code null}, replacing the parent's implementation.
     * <p>
     * {@link #choose(Object)} performs its own filtering and does not use a predicate.
     * NOTE(review): any external caller that dereferences this result will hit a NullPointerException —
     * confirm nothing relies on it.
     * </p>
     *
     * @return always {@code null}
     */
    @Override
    public AbstractServerPredicate getPredicate() {
        return null;
    }

    /**
     * Chooses a server for the current request.
     * <p>
     * When the route rule's caller/callee restrictions match this invocation, requests that fall inside the
     * rule's traffic weight are routed to the rule's target servers and all other requests to the remaining
     * servers. When they do not match, every server known to the load balancer is a candidate. A candidate
     * is picked round-robin and returned if it is alive and ready to serve; otherwise another pick is made,
     * up to {@value #MAX_SELECTION_ATTEMPTS} times.
     * </p>
     *
     * @param key ignored by this rule
     * @return the chosen server, or {@code null} if there is no load balancer, no reachable servers,
     *         no candidate servers, or no alive candidate within the attempt limit
     */
    @Override
    public Server choose(Object key) {
        RibbonFilterContext ribbonCtx = RibbonFilterContextHolder.getCurrentContext();
        Map<String, String> attributes = ribbonCtx.getAttributes();

        // Assumes the configured load balancer is a BaseLoadBalancer (its client config is needed to
        // resolve the callee name); any other ILoadBalancer implementation fails this cast.
        BaseLoadBalancer loadBalancer = (BaseLoadBalancer) this.getLoadBalancer();
        if (loadBalancer == null) {
            return null;
        }

        // The rule applies only when caller and callee both match; the weight check is evaluated only then.
        boolean isMatchedCallerAndCallee = this.compareCallerAndCallee(attributes, loadBalancer);
        boolean isRuleRequestWithinWeight =
                isMatchedCallerAndCallee && ServerHelper.isRuleRequestWithinWeight(attributes);
        String ruleValue = ribbonCtx.get(RouteKey.ROUTE_RULE_KEY);

        for (int attempt = 0; attempt < MAX_SELECTION_ATTEMPTS; attempt++) {
            // Re-read the server lists on every attempt so membership changes are picked up.
            List<Server> reachableServers = loadBalancer.getReachableServers();
            List<Server> allServers = loadBalancer.getAllServers();
            if (reachableServers.isEmpty() || allServers.isEmpty()) {
                LOGGER.warn("No up servers available from load balancer: {}", loadBalancer);
                return null;
            }

            List<Server> candidates = this.selectCandidateServers(
                    allServers, ruleValue, isMatchedCallerAndCallee, isRuleRequestWithinWeight);
            if (candidates.isEmpty()) {
                // Servers exist, but none of them is eligible under the route rule.
                LOGGER.warn("No servers matching the route rule available from load balancer: {}", loadBalancer);
                return null;
            }

            Server candidate = candidates.get(this.incrementAndGetModulo(candidates.size()));
            if (candidate == null) {
                // Null entry in the candidate list: yield before making the next attempt.
                Thread.yield();
                continue;
            }
            if (candidate.isAlive() && candidate.isReadyToServe()) {
                return candidate;
            }
        }

        LOGGER.warn("No available alive servers after {} tries from load balancer: {}",
                MAX_SELECTION_ATTEMPTS, loadBalancer);
        return null;
    }

    /**
     * Returns the servers the current request may be routed to.
     *
     * @param allServers                all servers known to the load balancer
     * @param ruleValue                 route rule value from the Ribbon filter context (may be {@code null})
     * @param isMatchedCallerAndCallee  whether the rule's caller/callee restrictions match this invocation
     * @param isRuleRequestWithinWeight whether this request falls inside the rule's traffic weight
     * @return the rule's target servers, the rule's other servers, or {@code allServers} when the rule
     *         does not apply
     */
    private List<Server> selectCandidateServers(List<Server> allServers, String ruleValue,
                                                boolean isMatchedCallerAndCallee,
                                                boolean isRuleRequestWithinWeight) {
        if (!isMatchedCallerAndCallee) {
            // The rule does not apply to this invocation: every server is a candidate.
            return allServers;
        }
        FilterResult filterResult = ServerHelper.filterRule(allServers, ruleValue);
        return isRuleRequestWithinWeight ? filterResult.getTargetServers() : filterResult.getOthersServers();
    }

    /**
     * Atomically advances the round-robin counter and returns the new position.
     *
     * @param modulo size of the list being indexed; callers pass a positive value
     * @return the next position, in {@code [0, modulo)}
     */
    private int incrementAndGetModulo(int modulo) {
        int current;
        int next;
        do {
            current = this.nextServerPositionCounter.get();
            next = (current + 1) % modulo;
        } while (!this.nextServerPositionCounter.compareAndSet(current, next));

        return next;
    }

    /**
     * Checks whether the route rule's caller/callee restrictions match this invocation.
     * <p>
     * The restrictions are read from the context attributes {@code ROUTE_RULE_KEY + "-caller"} and
     * {@code ROUTE_RULE_KEY + "-callee"}; a missing value is treated as the wildcard {@code "*"} (any service).
     * The caller is compared with this application's name and the callee with the load balancer's client name.
     * </p>
     *
     * @param attributes   attributes of the current Ribbon filter context
     * @param loadBalancer load balancer of the service being called
     * @return {@code true} if every non-wildcard restriction matches
     */
    private boolean compareCallerAndCallee(Map<String, String> attributes, BaseLoadBalancer loadBalancer) {
        String callerInRibbonCtx = attributes.get(RouteKey.ROUTE_RULE_KEY + "-caller");
        String calleeInRibbonCtx = attributes.get(RouteKey.ROUTE_RULE_KEY + "-callee");

        // Absent restrictions default to the wildcard, i.e. all services.
        if (callerInRibbonCtx == null) {
            callerInRibbonCtx = WILDCARD;
        }
        if (calleeInRibbonCtx == null) {
            calleeInRibbonCtx = WILDCARD;
        }

        boolean anyCaller = WILDCARD.equals(callerInRibbonCtx);
        boolean anyCallee = WILDCARD.equals(calleeInRibbonCtx);
        if (anyCaller && anyCallee) {
            return true;
        }

        // The client config is only consulted when the callee is actually restricted.
        boolean callerMatches = anyCaller || callerInRibbonCtx.equals(this.applicationName);
        boolean calleeMatches = anyCallee
                || calleeInRibbonCtx.equals(loadBalancer.getClientConfig().getClientName());
        return callerMatches && calleeMatches;
    }

}
